library(tidyverse)
library(cjbart)
library(parallel)
library(haven)
library(pals)
library(sandwich)
library(lmtest)
library(cowplot)
library(xtable)

# Set the global `mc.cores` option for the rest of the session
options(mc.cores = 8)

#### Refresh data ####

source("0_conjoint_functions.R")

# Read the formatted conjoint file and pass it through set_refs().
# NOTE(review): set_refs() is not defined in this script -- presumably it comes
# from the helper file sourced above; confirm what it does there.
conjoint_data <- set_refs(read_csv("formatted_data/conjoint_data.csv"))

# Analysis frame for the binary-choice models
bin_data <- conjoint_data %>%
  # Keep rows with a non-missing `jobs` value and drop rows for country "CHN"
  filter(!is.na(jobs), country != "CHN") %>%
  # Prefix the gdp / jobs values with the attribute name; the level order is
  # taken from the corresponding columns of the unfiltered conjoint_data
  mutate(
    gdp = factor(
      paste0("gdp_", gdp),
      levels = paste0("gdp_", levels(conjoint_data[["gdp"]]))
    ),
    jobs = factor(
      paste0("jobs_", jobs),
      levels = paste0("jobs_", levels(conjoint_data[["jobs"]]))
    )
  ) %>%
  # Previously also considered for removal (kept disabled):
  # select(-all_of(c("health_spend_pca","health_compl_pca","intl_flights"))) %>%
  # Drop every column that is excluded from bin_data
  select(
    -c(
      choice_cont, weights, ends_with("_sub"),
      id,
      food_indx, who_indx,
      starts_with("politics_"),
      starts_with("health_pol_"),
      matches("(C|E)[1-9]"),
      region_merge,
      attention,
      system, POLCONV_VDEM
    )
  )

#### Hyperparameter tuning ####

# Candidate values for the `ntree` argument of BART::mc.pbart(), compared by
# K-fold cross-validation
ntrees <- c(100, 250, 500, 750, 1000, 1250)
nfolds <- 5

set.seed(89)

# Assign every row of bin_data to one of `nfolds` folds, drawn independently
# with replacement (fold sizes are therefore only approximately equal)
cv_indx <- sample(seq_len(nfolds), size = nrow(bin_data), replace = TRUE)

# Modelling frame: character columns become factors *before* splitting so that
# train and test folds share identical factor levels; c_id is dropped
cv_data <- bin_data %>%
  mutate(across(where(is.character), as.factor)) %>%
  select(-c_id)

# For each candidate number of trees, fit a probit BART on the training folds
# and score the held-out fold; the CV error is the unweighted mean across
# folds of the squared error between choice_bin and the predicted probability.
# vapply() guarantees exactly one numeric value per fold / per candidate.
cv_errors <- vapply(ntrees, function(ntree) {

  print(ntree)

  fold_mse <- vapply(seq_len(nfolds), function(fold) {

    print(paste0("/", fold))

    X_train <- cv_data %>%
      filter(cv_indx != fold) %>%
      select(-choice_bin) %>%
      as.data.frame()
    X_test <- cv_data %>%
      filter(cv_indx == fold) %>%
      select(-choice_bin) %>%
      as.data.frame()

    y_train <- cv_data$choice_bin[cv_indx != fold]
    y_test <- cv_data$choice_bin[cv_indx == fold]

    # Honour the session-wide mc.cores option; fall back to 8 if it is unset
    y_pred <- BART::mc.pbart(
      ntree = ntree,
      x.train = X_train, y.train = y_train,
      x.test = X_test,
      mc.cores = getOption("mc.cores", 8L)
    )$prob.test.mean

    mean((y_test - y_pred)^2)

  }, numeric(1))

  mean(fold_mse)

}, numeric(1))

tune_results <- data.frame(trees = as.character(ntrees),
                           MSE = cv_errors)

# Make sure the output directory exists so the (expensive) CV results are not
# lost to a failed write at the very end
dir.create("tables", showWarnings = FALSE, recursive = TRUE)

xtable::xtable(tune_results, digits = 4) %>%
  print(., only.contents = TRUE, hline.after = NULL,
        file = "tables/hyper_tuning_results.tex")
